/*
 *  Copyright 2008-2013 NVIDIA Corporation
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

#pragma once

#include <thrust/detail/config.h>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <thrust/detail/internal_functional.h>
#include <thrust/detail/temporary_array.h>
#include <thrust/detail/type_traits.h>
#include <thrust/detail/type_traits/iterator/is_output_iterator.h>
#include <thrust/iterator/iterator_traits.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/scan.h>
#include <thrust/scatter.h>
#include <thrust/transform.h>

#include <cuda/std/iterator>
#include <cuda/std/limits>

THRUST_NAMESPACE_BEGIN
namespace system
{
namespace detail
{
namespace generic
{
namespace detail
{

// Binary operator used to scan (value, flag) tuples.
//
// Given a left operand `a` and a right operand `b`:
//   * if the flag of `b` is set, the resulting value is the value of `b`
//     unchanged (the left operand's value is discarded);
//   * otherwise the resulting value is binary_op(value of a, value of b).
// The resulting flag is the bitwise OR of both operands' flags.
template <typename ValueType, typename TailFlagType, typename AssociativeOperator>
struct reduce_by_key_functor
{
  // (value, flag) pair that is both consumed and produced by operator()
  using result_type = typename thrust::tuple<ValueType, TailFlagType>;

  // user-supplied operator used to combine two values
  AssociativeOperator binary_op;

  _CCCL_HOST_DEVICE reduce_by_key_functor(AssociativeOperator op)
      : binary_op(op)
  {}

  _CCCL_HOST_DEVICE result_type operator()(result_type a, result_type b)
  {
    // name the tuple elements; these are references into the by-value arguments
    auto& lhs_value = thrust::get<0>(a);
    auto& lhs_flag  = thrust::get<1>(a);
    auto& rhs_value = thrust::get<0>(b);
    auto& rhs_flag  = thrust::get<1>(b);

    // a set flag on the right operand restarts the accumulation at rhs_value
    return result_type(rhs_flag ? rhs_value : binary_op(lhs_value, rhs_value), lhs_flag | rhs_flag);
  }
};

} // end namespace detail

// Generic reduce_by_key: for every run of consecutive keys in
// [keys_first, keys_last) that binary_pred treats as equal, writes one key to
// keys_output and the binary_op-combination of the corresponding values to
// values_output.
//
// The implementation is composed of other Thrust algorithms:
//   1. transform    -> flag the first ("head") and last ("tail") element of each run
//   2. inclusive_scan over (value, head flag) tuples -> running result per run
//   3. exclusive_scan of the tail flags -> output index of each run
//   4. scatter_if   -> write flagged keys / values to their output index
//
// Returns a pair of iterators: keys_output and values_output, each advanced by
// the number of runs that were found. For an empty input, the output
// iterators are returned unchanged.
//
// Note: flags and output indices are stored as unsigned int (see FlagType).
template <typename ExecutionPolicy,
          typename InputIterator1,
          typename InputIterator2,
          typename OutputIterator1,
          typename OutputIterator2,
          typename BinaryPredicate,
          typename BinaryFunction>
_CCCL_HOST_DEVICE thrust::pair<OutputIterator1, OutputIterator2> reduce_by_key(
  thrust::execution_policy<ExecutionPolicy>& exec,
  InputIterator1 keys_first,
  InputIterator1 keys_last,
  InputIterator2 values_first,
  OutputIterator1 keys_output,
  OutputIterator2 values_output,
  BinaryPredicate binary_pred,
  BinaryFunction binary_op)
{
  using difference_type = thrust::detail::it_difference_t<InputIterator1>;

  using FlagType = unsigned int; // TODO use difference_type

  // Use the input iterator's value type per https://wg21.link/P0571
  using ValueType = thrust::detail::it_value_t<InputIterator2>;

  // shorthands for the temporary storage and the scan operator used below
  using FlagArray       = thrust::detail::temporary_array<FlagType, ExecutionPolicy>;
  using ValueArray      = thrust::detail::temporary_array<ValueType, ExecutionPolicy>;
  using SegmentedScanOp = detail::reduce_by_key_functor<ValueType, FlagType, BinaryFunction>;

  // empty input: nothing is written, hand the output iterators straight back
  if (keys_first == keys_last)
  {
    return thrust::make_pair(keys_output, values_output);
  }

  // number of input elements and the matching end of the value range
  difference_type num_items = keys_last - keys_first;
  InputIterator2 values_end = values_first + num_items;

  // adjacent key pairs (k[i], k[i+1]) are formed from
  // [keys_first, keys_last - 1) and the range starting at keys_first + 1
  auto left_keys_end    = keys_last - 1;
  auto right_keys_begin = keys_first + 1;

  // head flags: element i (i > 0) is flagged when !binary_pred(k[i-1], k[i]);
  // element 0 always starts a run
  FlagArray segment_heads(exec, num_items);
  thrust::transform(
    exec, keys_first, left_keys_end, right_keys_begin, segment_heads.begin() + 1, ::cuda::std::not_fn(binary_pred));
  segment_heads[0] = 1;

  // tail flags: element i (i < num_items - 1) is flagged when
  // !binary_pred(k[i], k[i+1]); the final element always ends a run
  FlagArray segment_tails(exec, num_items); // COPY INSTEAD OF TRANSFORM
  thrust::transform(
    exec, keys_first, left_keys_end, right_keys_begin, segment_tails.begin(), ::cuda::std::not_fn(binary_pred));
  segment_tails[num_items - 1] = 1;

  // storage for the per-element running results and the per-element output index
  ValueArray running_values(exec, num_items);
  FlagArray output_index(exec, num_items);

  // scan the values, restarting at every head flag; the flag component written
  // to output_index here is overwritten by the exclusive_scan that follows
  thrust::inclusive_scan(
    exec,
    thrust::make_zip_iterator(values_first, segment_heads.begin()),
    thrust::make_zip_iterator(values_end, segment_heads.end()),
    thrust::make_zip_iterator(running_values.begin(), output_index.begin()),
    SegmentedScanOp(binary_op));

  // output_index[i] = number of tail flags strictly before position i
  thrust::exclusive_scan(
    exec, segment_tails.begin(), segment_tails.end(), output_index.begin(), FlagType(0), ::cuda::std::plus<FlagType>());

  // number of unique keys: index of the last run plus one
  FlagType num_segments = output_index[num_items - 1] + 1;

  // keys are taken from the head of each run, accumulated values from its tail
  thrust::scatter_if(exec, keys_first, keys_last, output_index.begin(), segment_heads.begin(), keys_output);
  thrust::scatter_if(
    exec, running_values.begin(), running_values.end(), output_index.begin(), segment_tails.begin(), values_output);

  return thrust::make_pair(keys_output + num_segments, values_output + num_segments);
} // end reduce_by_key()

// Overload without a key predicate or reduction operator: forwards to
// thrust::reduce_by_key, supplying equal_to over the key iterator's value
// type as the BinaryPredicate.
template <typename ExecutionPolicy,
          typename InputIterator1,
          typename InputIterator2,
          typename OutputIterator1,
          typename OutputIterator2>
_CCCL_HOST_DEVICE thrust::pair<OutputIterator1, OutputIterator2> reduce_by_key(
  thrust::execution_policy<ExecutionPolicy>& exec,
  InputIterator1 keys_first,
  InputIterator1 keys_last,
  InputIterator2 values_first,
  OutputIterator1 keys_output,
  OutputIterator2 values_output)
{
  // value type of the key range
  using KeyType = thrust::detail::it_value_t<InputIterator1>;

  // default BinaryPredicate: equality comparison of keys
  using DefaultPredicate = ::cuda::std::equal_to<KeyType>;

  return thrust::reduce_by_key(
    exec, keys_first, keys_last, values_first, keys_output, values_output, DefaultPredicate());
} // end reduce_by_key()

// Overload with a key predicate but no reduction operator: forwards to
// thrust::reduce_by_key, supplying plus<T> as the BinaryFunction.
//
// T is chosen with is_output_iterator<OutputIterator2>: when that condition
// holds, T is the value type of InputIterator2; otherwise it is the value
// type of OutputIterator2.
template <typename ExecutionPolicy,
          typename InputIterator1,
          typename InputIterator2,
          typename OutputIterator1,
          typename OutputIterator2,
          typename BinaryPredicate>
_CCCL_HOST_DEVICE thrust::pair<OutputIterator1, OutputIterator2> reduce_by_key(
  thrust::execution_policy<ExecutionPolicy>& exec,
  InputIterator1 keys_first,
  InputIterator1 keys_last,
  InputIterator2 values_first,
  OutputIterator1 keys_output,
  OutputIterator2 values_output,
  BinaryPredicate binary_pred)
{
  // candidate operand types for the default operator
  using InputValueType  = thrust::detail::it_value_t<InputIterator2>;
  using OutputValueType = thrust::detail::it_value_t<OutputIterator2>;

  using T = ::cuda::std::_If<thrust::detail::is_output_iterator<OutputIterator2>, InputValueType, OutputValueType>;

  // default BinaryFunction: addition over T
  using DefaultFunction = ::cuda::std::plus<T>;

  return thrust::reduce_by_key(
    exec, keys_first, keys_last, values_first, keys_output, values_output, binary_pred, DefaultFunction());
} // end reduce_by_key()

} // end namespace generic
} // end namespace detail
} // end namespace system
THRUST_NAMESPACE_END
